跳到主要内容

高级排序算法 归并排序

归并排序的介绍

归并排序(MERGE- SORT)是利用归并的思想实现的排序方法,该算法采用经典的分治(divide -and conquer)策略,分治法将问题分(divide)成一些小的问题然后递归求解,而治(conquer)的阶段则将分的阶段得到的各答案 “修补” 在一起,即分而治之。

归并排序的时间复杂度为 O(nlogn)O(nlogn)

归并排序是怎么工作的?

就是分治策略

上图的上半部分就是分的过程(divide),而下半部分就是治(conquer),将分的阶段得到的各答案 “修补” 在一起

注意:其实分的过程没有做什么实际性的工作,主要是后面治的过程,治的过程看下面代码的 merge 函数

这里治(conquer)的过程实际就是插入法(注意,它和插入排序还是有区别的,因为它已经排好序了,所以只需简单的比对两个指针 p1、p2 指向的节点谁比较大就行了),不过它通过一个辅助数组来存储 “插入” 的中间结果,等插完了再把辅助数组中的元素拷贝到原数组里面,如下图所示 p1、p2 指针依次比较所在位置的值的大小,把较小的那个填入辅助数组中,然后较小的指针向后移动一位,重复比较的过程

归并排序编写

归并排序的具体步骤看下面注释

public class Merge {

private static Comparable[] assist;

public static <T extends Comparable<T>> void sort(T[] a) {
//1. 初始化辅助数组 assist
assist = new Comparable[a.length];
//2. 定义一个 lo变量,和 hi变量,分别记录数组中最小的索引和最大的索引
int lo = 0;
int hi = a.length - 1;
//3. 调用 sort 重载方法完成数组 a中,从索引 lo到索引 hi的元素排序
sort(a, lo, hi);
}

private static <T extends Comparable<T>> void sort(T[] a, int lo, int hi) {
// 先做安全性校验
if (lo >= hi) {
return;
}
// 对 lo到 hi之间的数据进行分为两个组
int mid = lo + (hi - lo) / 2;

// 再对每一组分别排序,分的过程
sort(a, lo, mid);
sort(a, mid + 1, hi);

// 最后对这两组数据进行归并(治)
merge(a, lo, mid, hi);
}

// 治(conquer)的过程
private static <T extends Comparable<T>> void merge(T[] a, int lo, int mid, int hi) {
// 定义三个指针
int i = lo;
int p1 = lo;
int p2 = mid + 1;

// 遍历,移动 p1指针和 p2指针,比较对应索引处的值,找出最小的那个放到辅助数组里面
while (p1 <= mid && p2 <= hi) {
// 比较索引处的值
if (less(a[p1], a[p2])) {
// 放到辅助数组里面,同时指针要移动(别忘了辅助数组的指针也要移动一位)
assist[i++] = a[p1++];
} else {
assist[i++] = a[p2++];
}
}

// 遍历,如果 p1指针还没走完,那么顺序移动 p1指针,把对应的元素放到辅助数组的对应索引处
while (p1 <= mid) {
assist[i++] = a[p1++];
}

// 遍历,如果 p2指针还没走完,那么顺序移动 p2指针,把对应的元素放到辅助数组的对应索引处
while (p2 <= hi) {
assist[i++] = a[p2++];
}

// 把辅助数组中的元素拷贝到原数组里面
for (int index = lo; index <= hi; index++) {
a[index] = (T) assist[index];
}
}

private static <T extends Comparable<T>> boolean less(T v, T w) {
return v.compareTo(w) < 0;
}
}

计算时间复杂度

  • a 代表 子问题是等规模的情况下,调用了多少次
  • 后面的 O(nd)O(n^d) 代表除去调用子过程之外,剩下过程的时间复杂度
  • n/bn/b 代表子问题的规模

可以看到调用了两次子问题,所以 a = 2,又因为每次问题的规模都缩小一半,所以 b = 2,而 O(nd)O(n^d) 在这里则是 O(n)O(n) (因为 merge 方法基本就是 O(n)O(n))。

取得 a、b、d 然后再根据下面这个图里的条件判断,最后可以得出右边的时间复杂度, d=logbad = log_ba

最后算的时间复杂度为:

O(nlogn)O(n * logn)

例题:求小合问题

问题:在一个数组中,每一个元素左边比当前元素值小的元素值累加起来,叫做这个数组的小和

例如:[2,3,4,1,5]

  • 2 左边比 2 小的元素:无
  • 3 左边比 3 小的元素:2
  • 4 左边比 4 小的元素:2,3
  • 1 左边比 1 小的元素:无
  • 5 左边比 5 小的元素:2,3,4,1

小和 small_sum = 2 + 2 + 3 + 2 + 3 + 4 + 1 = 17

解决思路:

把上面的过程转换一下,变成求右边比当前数大的值。

  • 2 右边比 2 大的元素:3,4,5 => 三个数
  • 3 右边比 3 大的元素:4,5 => 两个数
  • 4 右边比 4 大的元素:5 => 一个数
  • 1 右边比 1 大的元素:5 => 一个数
  • 5 右边比 5 大的元素:无 => 零

所以小和 small_sum = 3 2 + 2 3 + 4 + 1 = 17

所以这两个过程是等效的

然后 merge 的过程参考 左神这个视频 1:05:01 秒开始的过程

所以最终代码为:

public class SmallSum {

private static int[] assist;

public static int smallSum(int[] arr) {
//1. 初始化辅助数组 assist
assist = new int[a.length];
//2. 定义一个 lo变量,和 hi变量,分别记录数组中最小的索引和最大的索引
int lo = 0;
int hi = a.length - 1;
//3. 调用 sort 重载方法完成数组 a中,从索引 lo到索引 hi的元素排序
return process(a, lo, hi);
}

private static int process(int[] arr, int lo, int hi) {
// 先做安全性校验
if (lo >= hi) {
return 0;
}
// 对 lo到 hi之间的数据进行分为两个组
int mid = lo + (hi - lo) / 2;

// 再对每一组分别排序,分的过程
return process(arr, lo, mid) +
process(arr, mid + 1, hi) +
merge(arr, lo, mid, hi);
}

// 治(conquer)的过程
private static int merge(int[] arr, int lo, int mid, int hi) {
// 定义三个指针
int i = lo;
int p1 = lo;
int p2 = mid + 1;
int res = 0;

// 遍历,移动 p1指针和 p2指针,比较对应索引处的值,找出最小的那个放到辅助数组里面
while (p1 <= mid && p2 <= hi) {
// hi - p2 + 1 计算比 arr[p1] 大的元素数量
res += arr[p1] < arr[p2] ? (hi - p2 + 1) * arr[p1] : 0;
// 比较索引处的值
if (arr[p1] < arr[p2]) {
// 放到辅助数组里面,同时指针要移动(别忘了辅助数组的指针也要移动一位)
assist[i++] = arr[p1++];
} else {
assist[i++] = arr[p2++];
}
}

// 遍历,如果 p1指针还没走完,那么顺序移动 p1指针,把对应的元素放到辅助数组的对应索引处
while (p1 <= mid) {
assist[i++] = arr[p1++];
}

// 遍历,如果 p2指针还没走完,那么顺序移动 p2指针,把对应的元素放到辅助数组的对应索引处
while (p2 <= hi) {
assist[i++] = arr[p2++];
}

// 把辅助数组中的元素拷贝到原数组里面
for (int index = lo; index <= hi; index++) {
a[index] = assist[index];
}

return res;
}
}